{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# Encoder-Decoder Architecture\n",
    ":label:`sec_encoder-decoder`\n",
    "\n",
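    "As we discussed in :numref:`sec_machine_translation`,\n",
    "machine translation is a core problem for sequence transduction models,\n",
    "whose inputs and outputs are both variable-length sequences.\n",
    "To handle inputs and outputs of this type,\n",
    "we can design an architecture with two major components.\n",
    "The first component is an *encoder*:\n",
    "it takes a variable-length sequence as input\n",
    "and transforms it into an encoded state with a fixed shape.\n",
    "The second component is a *decoder*:\n",
    "it maps the fixed-shape encoded state to a variable-length sequence.\n",
    "This is called the *encoder-decoder* architecture,\n",
    "as depicted in :numref:`fig_encoder_decoder`.\n",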
    "\n",
    "![The encoder-decoder architecture](https://d2l.ai/_images/encoder-decoder.svg)\n",
    ":label:`fig_encoder_decoder`\n",
    "\n",
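    "Let us take machine translation from English to French as an example.\n",
    "Given an input sequence in English: “They”, “are”, “watching”, “.”,\n",
    "this encoder-decoder architecture first encodes the variable-length input into a state,\n",
    "then decodes the state\n",
    "to generate the translated sequence token by token as output:\n",
    "“Ils”, “regardent”, “.”.\n",
    "Since the encoder-decoder architecture forms the basis\n",
    "of the different sequence transduction models in subsequent sections,\n",
    "this section turns the architecture into an interface that later code will implement.\n",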
    "\n",
    "## Encoder\n",
    "\n",
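    "In the encoder interface, we only specify that the encoder takes a variable-length sequence `X` as input.\n",
    "Any model that inherits this base `Encoder` class provides the actual implementation.\n"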
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 1,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "public abstract class Encoder extends AbstractBlock {\n",
    "\n",
    "    /* The base encoder interface for the encoder-decoder architecture. */\n",
    "    private static final byte VERSION = 1;\n",
    "\n",
    "    public Encoder() {\n",
    "        super(VERSION);\n",
    "    }\n",
    "\n",
    "    // inputs holds the variable-length input sequence X; subclasses return the encoder outputs.\n",
    "    @Override\n",
    "    protected abstract NDList forwardInternal(\n",
    "            ParameterStore parameterStore,\n",
    "            NDList inputs,\n",
    "            boolean training,\n",
    "            PairList<String, Object> params);\n",
    "\n",
    "    // Shape inference is not supported by this base interface.\n",
    "    @Override\n",
    "    public Shape[] getOutputShapes(Shape[] inputShapes) {\n",
    "        throw new UnsupportedOperationException(\"Not implemented\");\n",
    "    }\n",
    "}"
   ]
  },
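  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the fixed-shape contract concrete, here is a minimal plain-Java sketch that does not depend on DJL.\n",
    "The class `ToyEncoderDemo`, its `embed` rule, and the use of mean pooling are hypothetical stand-ins for a learned encoder.\n",
    "Each token id is mapped to a vector of length `DIM` and the vectors are averaged,\n",
    "so input sequences of any length yield a state of the same shape.\n",
    "\n",
    "```java\n",
    "import java.util.Arrays;\n",
    "\n",
    "/** Hypothetical toy encoder: any input length yields a state of length DIM. */\n",
    "public class ToyEncoderDemo {\n",
    "    static final int DIM = 3;\n",
    "\n",
    "    // Deterministic toy embedding of a token id (illustration only, not learned).\n",
    "    static double[] embed(int tokenId) {\n",
    "        double[] e = new double[DIM];\n",
    "        for (int d = 0; d < DIM; d++) {\n",
    "            e[d] = tokenId * (d + 1);\n",
    "        }\n",
    "        return e;\n",
    "    }\n",
    "\n",
    "    // Encodes a variable-length sequence into a fixed-shape state by mean pooling.\n",
    "    static double[] encode(int[] tokens) {\n",
    "        if (tokens.length == 0) {\n",
    "            throw new IllegalArgumentException(\"tokens must be non-empty\");\n",
    "        }\n",
    "        double[] state = new double[DIM];\n",
    "        for (int t : tokens) {\n",
    "            double[] e = embed(t);\n",
    "            for (int d = 0; d < DIM; d++) {\n",
    "                state[d] += e[d];\n",
    "            }\n",
    "        }\n",
    "        for (int d = 0; d < DIM; d++) {\n",
    "            state[d] /= tokens.length;\n",
    "        }\n",
    "        return state;\n",
    "    }\n",
    "\n",
    "    public static void main(String[] args) {\n",
    "        double[] s1 = encode(new int[] {1, 2, 3});\n",
    "        double[] s2 = encode(new int[] {4, 5, 6, 7, 8});\n",
    "        System.out.println(s1.length + \" \" + s2.length); // prints 3 3\n",
    "        System.out.println(Arrays.toString(s1));        // prints [2.0, 4.0, 6.0]\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "Mean pooling ignores word order, so it is far too weak for translation.\n",
    "The sketch only shows the shape contract that a real encoder, such as the RNN encoder in the next section, also satisfies.\n"
   ]
  },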
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 4
   },
   "source": [
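    "## Decoder\n",
    "\n",
    "In the following decoder interface, we add an `initState()` function\n",
    "that converts the encoder outputs (`encOutputs`) into the encoded state.\n",
    "Note that this step may need extra inputs, such as the valid lengths of the input sequence,\n",
    "which were explained in :numref:`subsec_mt_data_loading`.\n",
    "To generate a variable-length sequence token by token,\n",
    "at each time step the decoder maps an input\n",
    "(e.g., the token generated at the previous time step) and the encoded state\n",
    "into an output token at the current time step.\n"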
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 5,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "public abstract class Decoder extends AbstractBlock {\n",
    "\n",
    "    /* The base decoder interface for the encoder-decoder architecture. */\n",
    "    private static final byte VERSION = 1;\n",
    "\n",
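    "    // Not used by this base interface; decoders with attention can store their attention weights here.\n",
    "    public NDArray attentionWeights;\n",
    "\n",
    "    public Decoder() {\n",
    "        super(VERSION);\n",
    "    }\n",
    "\n",
    "    // inputs holds the decoder input followed by the state returned by initState().\n",
    "    @Override\n",
    "    protected abstract NDList forwardInternal(\n",
    "            ParameterStore parameterStore,\n",
    "            NDList inputs,\n",
    "            boolean training,\n",
    "            PairList<String, Object> params);\n",
    "\n",
    "    // Converts the encoder outputs into the decoder's initial (encoded) state.\n",
    "    public abstract NDList initState(NDList encOutputs);\n",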
    "\n",
    "    @Override\n",
    "    public Shape[] getOutputShapes(Shape[] inputShapes) {\n",
    "        throw new UnsupportedOperationException(\"Not implemented\");\n",
    "    }\n",
    "}\n"
   ]
  },
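  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The per-time-step mapping described for the decoder can be illustrated with another plain-Java sketch, again without DJL.\n",
    "`ToyDecoderDemo`, the token ids `BOS` and `EOS`, and the hand-written `step` rule are assumptions made for illustration; a real decoder learns its step function.\n",
    "Starting from a beginning-of-sequence token, each generated token is fed back as the next input\n",
    "until the end-of-sequence token appears or `maxSteps` is reached,\n",
    "so states of the same fixed shape can produce outputs of different lengths.\n",
    "\n",
    "```java\n",
    "import java.util.ArrayList;\n",
    "import java.util.List;\n",
    "\n",
    "/** Hypothetical toy decoder: each step maps (previous token, state) to the next token. */\n",
    "public class ToyDecoderDemo {\n",
    "    static final int BOS = 0; // assumed id of the beginning-of-sequence token\n",
    "    static final int EOS = 1; // assumed id of the end-of-sequence token\n",
    "\n",
    "    // One decoding step. Toy rule (not learned): state[0] is a remaining-token budget.\n",
    "    static int step(int prevToken, int[] state) {\n",
    "        if (state[0] == 0) {\n",
    "            return EOS;\n",
    "        }\n",
    "        state[0]--;\n",
    "        return prevToken + 2;\n",
    "    }\n",
    "\n",
    "    // Feeds each generated token back in as the next input until EOS or maxSteps.\n",
    "    static List<Integer> decode(int[] initState, int maxSteps) {\n",
    "        int[] state = initState.clone(); // the state keeps the same fixed shape\n",
    "        List<Integer> output = new ArrayList<>();\n",
    "        int token = BOS;\n",
    "        for (int t = 0; t < maxSteps; t++) {\n",
    "            token = step(token, state);\n",
    "            if (token == EOS) {\n",
    "                break;\n",
    "            }\n",
    "            output.add(token);\n",
    "        }\n",
    "        return output;\n",
    "    }\n",
    "\n",
    "    public static void main(String[] args) {\n",
    "        System.out.println(decode(new int[] {2}, 10)); // prints [2, 4]\n",
    "        System.out.println(decode(new int[] {4}, 10)); // prints [2, 4, 6, 8]\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "Both calls start from a state of length one, yet they emit two and four tokens respectively.\n"
   ]
  },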
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 8
   },
   "source": [
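    "## [**Putting the Encoder and Decoder Together**]\n",
    "\n",
    "In summary, the encoder-decoder architecture contains both an encoder and a decoder,\n",
    "optionally with extra arguments.\n",
    "In the forward propagation, the output of the encoder is used to produce the encoded state,\n",
    "and this state is then used by the decoder as part of its input.\n"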
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 9,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "public class EncoderDecoder extends AbstractBlock {\n",
    "\n",
    "    /* The base class for the encoder-decoder architecture. */\n",
    "    private static final byte VERSION = 1;\n",
    "\n",
    "    public Encoder encoder;\n",
    "    public Decoder decoder;\n",
    "\n",
    "    public EncoderDecoder(Encoder encoder, Decoder decoder) {\n",
    "        super(VERSION);\n",
    "\n",
    "        this.encoder = encoder;\n",
    "        this.addChildBlock(\"encoder\", encoder);\n",
    "        this.decoder = decoder;\n",
    "        this.addChildBlock(\"decoder\", decoder);\n",
    "    }\n",
    "\n",
    "    /** {@inheritDoc} */\n",
    "    @Override\n",
    "    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {\n",
    "        // Left empty, so this wrapper does not initialize the encoder and decoder child blocks itself.\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    protected NDList forwardInternal(ParameterStore parameterStore,\n",
    "                                     NDList inputs, boolean training,\n",
    "                                     PairList<String, Object> params) {\n",
    "        // inputs = [encoder input X, decoder input]\n",
    "        NDArray encX = inputs.get(0);\n",
    "        NDArray decX = inputs.get(1);\n",
    "        // Encode X, then convert the encoder outputs into the decoder state.\n",
    "        NDList encOutputs = encoder.forward(parameterStore, new NDList(encX), training, params);\n",
    "        NDList decState = decoder.initState(encOutputs);\n",
    "        // The decoder receives its own input followed by the state.\n",
    "        return decoder.forward(parameterStore, new NDList(decX).addAll(decState), training, params);\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public Shape[] getOutputShapes(Shape[] inputShapes) {\n",
    "        throw new UnsupportedOperationException(\"Not implemented\");\n",
    "    }\n",
    "}\n"
   ]
  },
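  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The call order in `forwardInternal` can also be mirrored without DJL, which makes it easy to check in isolation.\n",
    "In this sketch the interfaces `ToyEncoder` and `ToyDecoder`, the class `ToyEncoderDecoder`, and the toy parts built in `demoModel()` are all hypothetical.\n",
    "They only reproduce the sequence: encode, then `initState`, then decode using the decoder input and the state.\n",
    "\n",
    "```java\n",
    "import java.util.Arrays;\n",
    "\n",
    "/** Hypothetical plain-Java mirror of EncoderDecoder.forwardInternal (no DJL involved). */\n",
    "public class ComposeDemo {\n",
    "    interface ToyEncoder {\n",
    "        double[] encode(int[] encX);\n",
    "    }\n",
    "\n",
    "    interface ToyDecoder {\n",
    "        double[] initState(double[] encOutputs);\n",
    "\n",
    "        int[] decode(int[] decX, double[] state);\n",
    "    }\n",
    "\n",
    "    static final class ToyEncoderDecoder {\n",
    "        private final ToyEncoder encoder;\n",
    "        private final ToyDecoder decoder;\n",
    "\n",
    "        ToyEncoderDecoder(ToyEncoder encoder, ToyDecoder decoder) {\n",
    "            this.encoder = encoder;\n",
    "            this.decoder = decoder;\n",
    "        }\n",
    "\n",
    "        // Same call order as forwardInternal: encode, initState, then decode decX with the state.\n",
    "        int[] forward(int[] encX, int[] decX) {\n",
    "            double[] encOutputs = encoder.encode(encX);\n",
    "            double[] state = decoder.initState(encOutputs);\n",
    "            return decoder.decode(decX, state);\n",
    "        }\n",
    "    }\n",
    "\n",
    "    // Toy parts: the state is the input length (a fixed shape of one number).\n",
    "    static ToyEncoderDecoder demoModel() {\n",
    "        ToyEncoder enc = encX -> new double[] {encX.length};\n",
    "        ToyDecoder dec = new ToyDecoder() {\n",
    "            @Override\n",
    "            public double[] initState(double[] encOutputs) {\n",
    "                return encOutputs.clone();\n",
    "            }\n",
    "\n",
    "            @Override\n",
    "            public int[] decode(int[] decX, double[] state) {\n",
    "                int[] out = new int[decX.length]; // one output per decoder input\n",
    "                for (int t = 0; t < decX.length; t++) {\n",
    "                    out[t] = decX[t] + (int) state[0];\n",
    "                }\n",
    "                return out;\n",
    "            }\n",
    "        };\n",
    "        return new ToyEncoderDecoder(enc, dec);\n",
    "    }\n",
    "\n",
    "    public static void main(String[] args) {\n",
    "        int[] out = demoModel().forward(new int[] {7, 8, 9}, new int[] {1, 2});\n",
    "        System.out.println(Arrays.toString(out)); // prints [4, 5]\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "Because `forward` depends only on the two interfaces, any pair of encoder and decoder implementations can be plugged in.\n",
    "`EncoderDecoder` accepts arbitrary `Encoder` and `Decoder` subclasses for the same reason.\n"
   ]
  },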
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 12
   },
   "source": [
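    "The term “state” in the encoder-decoder architecture\n",
    "may have inspired you to implement this architecture with neural networks that have states.\n",
    "In the next section, we will see how to apply recurrent neural networks\n",
    "to design sequence transduction models based on this encoder-decoder architecture.\n",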
    "\n",
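    "## Summary\n",
    "\n",
    "* The encoder-decoder architecture can handle inputs and outputs that are both variable-length sequences, and is thus suitable for sequence transduction problems such as machine translation.\n",
    "* The encoder takes a variable-length sequence as input and transforms it into an encoded state with a fixed shape.\n",
    "* The decoder maps the encoded state of a fixed shape to a variable-length sequence.\n",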
    "\n",
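    "## Exercises\n",
    "\n",
    "1. Suppose that we use neural networks to implement the encoder-decoder architecture. Do the encoder and the decoder have to be the same type of neural network?\n",
    "1. Besides machine translation, can you think of another application where the encoder-decoder architecture can be applied?\n"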
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
